package com.lw.leetcode.tree.c;

import java.util.*;

/**
 * Created with IntelliJ IDEA.
 * 1617. Count Subtrees With Max Distance Between Cities
 *
 * <p>Given {@code n} cities numbered {@code 1..n} and edges between them, counts, for every
 * distance {@code d} in {@code 1..n-1}, how many connected subsets of cities (subtrees) have a
 * maximum pairwise distance (diameter, measured in edges) of exactly {@code d}.
 *
 * <p>Approach: enumerate every subset of at least two cities as a bitmask and keep only the
 * connected ones. Measure each kept subset's diameter with two depth-first searches. The first
 * search finds a node farthest from an arbitrary member. The second counts the nodes on the
 * longest path starting at that node.
 *
 * <p>All {@code 2^n} subsets are enumerated, so {@code n} must be small. The instance keeps
 * mutable scratch state, so one instance must not be used from several threads at once.
 *
 * @author liw
 * @version 1.0
 * @date 2022/8/25 11:02
 */
public class CountSubgraphsForEachDiameter {

    /** Node is not part of the subset currently being examined. */
    private static final int OUTSIDE = 0;
    /** Node is in the current subset but has not (yet) been reached by {@link #check(int)}. */
    private static final int IN_SUBSET = 1;
    /** Node was reached by {@link #check(int)}. */
    private static final int CONNECTED = 2;
    /** Node was visited by {@link #findMax(int)}. */
    private static final int MAX_VISITED = 3;
    /** Node was visited by {@link #findCount(int)}. */
    private static final int COUNT_VISITED = 4;

    /** {@link #findMax(int)} packs its result as {@code (depth << DEPTH_SHIFT) + node}. */
    private static final int DEPTH_SHIFT = 8;
    /** Extracts the node index from a value packed by {@link #findMax(int)}. */
    private static final int NODE_MASK = (1 << DEPTH_SHIFT) - 1;

    public static void main(String[] args) {
        CountSubgraphsForEachDiameter test = new CountSubgraphsForEachDiameter();

        // {3,4,0}
        int n = 4;
        int[][] edges = {{1, 2}, {2, 3}, {2, 4}};

        // Other samples:
        //   n = 2, edges = {{1, 2}}          -> {1}
        //   n = 3, edges = {{1, 2}, {2, 3}}  -> {2,1}

        int[] ints = test.countSubgraphsForEachDiameter(n, edges);
        System.out.println(Arrays.toString(ints));
    }

    /** Per-node state for the subset being examined; values are the flag constants above. */
    private int[] flags;
    /** 0-based adjacency lists; rebuilt from scratch on every call. */
    private Map<Integer, List<Integer>> map = new HashMap<>();

    /**
     * Counts subtrees by diameter.
     *
     * @param n     number of cities; cities are numbered {@code 1..n}
     * @param edges undirected edges as 1-based {@code {u, v}} pairs
     * @return an array of length {@code max(n - 1, 0)} whose element {@code d - 1} is the number
     *     of connected city subsets whose diameter is exactly {@code d}
     */
    public int[] countSubgraphsForEachDiameter(int n, int[][] edges) {
        int[] counts = new int[Math.max(n - 1, 0)];
        this.flags = new int[n];
        // Rebuild the adjacency map on every call. Keeping the previous call's edges would
        // corrupt connectivity, or index past the new flags array, when the instance is reused.
        this.map = new HashMap<>();
        for (int[] edge : edges) {
            int u = edge[0] - 1;
            int v = edge[1] - 1;
            map.computeIfAbsent(u, k -> new ArrayList<>()).add(v);
            map.computeIfAbsent(v, k -> new ArrayList<>()).add(u);
        }

        int subsetCount = 1 << n;
        nextSubset:
        for (int mask = 3; mask < subsetCount; mask++) {
            // Single-city subsets (one bit set) have diameter 0 and are not counted.
            if ((mask & (mask - 1)) == 0) {
                continue;
            }
            Arrays.fill(flags, OUTSIDE);
            int start = -1;
            for (int j = 0; j < n; j++) {
                if ((mask & (1 << j)) != 0) {
                    flags[j] = IN_SUBSET;
                    start = j;
                }
            }
            // The subset is connected iff a DFS from one member reaches every other member.
            check(start);
            for (int flag : flags) {
                if (flag == IN_SUBSET) {
                    continue nextSubset;
                }
            }
            int farthest = findMax(start) & NODE_MASK;
            // findCount returns the number of nodes on the longest path. The diameter is that
            // count minus 1, and diameter d is stored at counts[d - 1].
            counts[findCount(farthest) - 2]++;
        }
        return counts;
    }

    /**
     * Returns the neighbours of {@code node}, or an empty list when the node has no edges.
     * The empty list avoids an NPE for a node that appears in no edge.
     */
    private List<Integer> neighbors(int node) {
        return map.getOrDefault(node, Collections.<Integer>emptyList());
    }

    /**
     * Returns the longest path length, in nodes, from {@code node} within the subset.
     * Walks nodes not yet marked {@link #COUNT_VISITED} and marks them as it goes.
     */
    private int findCount(int node) {
        flags[node] = COUNT_VISITED;
        int longest = 0;
        for (int next : neighbors(node)) {
            if (flags[next] == OUTSIDE || flags[next] == COUNT_VISITED) {
                continue;
            }
            longest = Math.max(longest, findCount(next));
        }
        return longest + 1;
    }

    /**
     * Finds a subset node farthest from {@code node}.
     *
     * <p>Marks visited nodes {@link #MAX_VISITED}. Returns {@code (depth << DEPTH_SHIFT) + index}.
     * Deeper nodes win; ties go to the larger index. The node index must fit in
     * {@link #NODE_MASK}, which always holds because {@code 1 << n} already limits {@code n}.
     */
    private int findMax(int node) {
        flags[node] = MAX_VISITED;
        int best = node; // depth 0: the node itself
        for (int next : neighbors(node)) {
            if (flags[next] == OUTSIDE || flags[next] == MAX_VISITED) {
                continue;
            }
            best = Math.max(best, findMax(next) + (1 << DEPTH_SHIFT));
        }
        return best;
    }

    /** Marks every subset member reachable from {@code node} as {@link #CONNECTED}. */
    private void check(int node) {
        flags[node] = CONNECTED;
        for (int next : neighbors(node)) {
            if (flags[next] == OUTSIDE || flags[next] == CONNECTED) {
                continue;
            }
            check(next);
        }
    }

}
